Fix dataloading #15

Draft

streichgeorg wants to merge 12 commits into main from georg/fixes

Conversation

@streichgeorg

  • Add some type coercion (float? -> float64, int? -> int64).
  • Make WSsample.__contains__ check which shards are loaded, which prevents some errors when working with partially computed features.
  • Add a fix for duplicate column names by dropping them. (Not sure if this is what we want to do; I had to add it to make some of the pretraining datasets work. Maybe there is a better solution.)

Member

This unfortunately makes the membership test as expensive as loading the data with .get('column').

How do you use this downstream?

Author

I have some datasets where certain transcripts are only computed partially. I could also catch the exception in my code, but I felt it is a bit counterintuitive if __contains__ succeeds but .get() fails.
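A minimal sketch of that behaviour (class and attribute names are hypothetical stand-ins for WSsample, not the actual implementation): __contains__ consults the set of loaded shards so that `in` and `.get()` agree.

```python
class Sample:
    """Hypothetical stand-in for WSsample."""

    def __init__(self, data: dict, loaded_columns: set):
        self._data = data              # column name -> value
        self._loaded = loaded_columns  # columns whose shards are loaded

    def __contains__(self, column: str) -> bool:
        # Only report columns that .get() could actually return.
        return column in self._data and column in self._loaded

    def get(self, column: str):
        if column not in self:
            raise KeyError(f"column {column!r} is not loaded")
        return self._data[column]
```

With this, `"col" in sample` never claims a column that a subsequent `sample.get("col")` would fail on.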

Author (@streichgeorg, Jan 26, 2026)

Feels like in most cases you want to do something like

if "col" in sample:
    # Do something with sample["col"]
else:
    # Do something else

Author

We could do the more thorough check only when include_partial_shards is set on the dataset?


self._filter_dfs[filter_name] = filter_df

rows_satisfying_filter = filter_df.sum().item()
Member

@rashishhume do you maybe know what's going on here?

Author

I deleted this since it caused a lot of log output at the start of my training jobs. Maybe it would make more sense to log this higher up.

if shard_subsample != 1:
    shard_list = rng.sample(shard_list, int(len(shard_list) * shard_subsample))

# TODO: Not sure if we want to drop the columns. I think previously we
Member

We could apply the renaming in the SQL path as well, which would make it consistent with the non-SQL API.

I think most of the issues with duplicate fields were fixed recently when we added the select in this line:

                    df = scan_ipc(shard_path, glob=False).select(fields)

Do you remember which duplicate columns are causing you headaches?

Author

It was language_whisper.txt, I think, since it is contained in all the shards with Whisper transcripts.

Member

Oh, interesting. We need to fix this then.

)

return exprs, pl.concat(row_merge).select(exprs)
def _common_dtype(col_name: str, a: pl.DataType, b: pl.DataType) -> pl.DataType:
Member

My guess would be that this is about some shards having a Null type because all the samples were None for a column?

Would concat with how="vertical_relaxed" help in this situation? (That would let Polars handle the coercion, hopefully in a sensible way.)

Author

The issue is that some metrics have shards stored as both float16 and float64.

Author (@streichgeorg, Jan 26, 2026)

I'll try vertical_relaxed. (I remember some Polars merge mode not handling the issue I was facing; I'm not 100% sure that was vertical_relaxed.)

Author (@streichgeorg)

@jpc it seems like you already have better solutions for some of these issues. These were mostly fixes I added while getting SFT to run. I'm fine with not merging this.

Georg Streich and others added 5 commits January 26, 2026 19:46